-
Notifications
You must be signed in to change notification settings - Fork 74
Combine for RaggedIterDomain #5716
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: raggediterdomain_clone
Are you sure you want to change the base?
Conversation
|
!test |
|
Review updated until commit be0e2ea Description
|
| Relevant files | |||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|
| Enhancement |
| ||||||||||
| Tests |
| ||||||||||
| Documentation |
|
PR Reviewer Guide
Here are some key observations to aid the review process:
| 🧪 PR contains tests |
| ⚡ Recommended focus areas for review |
Symbolic Extent Handling
|
Greptile SummaryThis PR introduces the Key changes:
The implementation follows nvFuser's design philosophy of trusting user-provided inputs (similar to arithmetic operations), with best-effort validation when feasible. Confidence Score: 4/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant User
participant RaggedIterDomain
participant Combine as Combine Expr
participant IterDomain
participant Partition as Partition Expr
User->>RaggedIterDomain: combine(component, ragged)
RaggedIterDomain->>RaggedIterDomain: Validate component != null
RaggedIterDomain->>RaggedIterDomain: Validate ragged != null
RaggedIterDomain->>RaggedIterDomain: Validate component is not RaggedIterDomain
RaggedIterDomain->>RaggedIterDomain: Validate parallel types are Serial
RaggedIterDomain->>RaggedIterDomain: Validate iter types are Iteration
alt ragged has Partition definition
RaggedIterDomain->>Partition: Get expected component
Partition-->>RaggedIterDomain: Return component IterDomain
RaggedIterDomain->>RaggedIterDomain: Validate component matches expected
else No Partition definition
Note over RaggedIterDomain: Trust user (Option 3)
end
RaggedIterDomain->>RaggedIterDomain: Get extents from ragged
RaggedIterDomain->>RaggedIterDomain: Validate extents is 1D
RaggedIterDomain->>RaggedIterDomain: Create symbolic extent Val
RaggedIterDomain->>IterDomain: Create combined IterDomain
IterDomain-->>RaggedIterDomain: Return combined_id
RaggedIterDomain->>Combine: Create Combine expression
Combine->>Combine: addOutput(combined_id)
Combine->>Combine: addInput(component)
Combine->>Combine: addInput(ragged)
RaggedIterDomain-->>User: Return combined IterDomain
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
5 files reviewed, 2 comments
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
6 files reviewed, 1 comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
7 files reviewed, 1 comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No files reviewed, no comments
|
!test |
This PR introduces the combine operation as discussed in the RaggedIterDomain design doc.
One design decision that I changed from the original design doc is about detecting and validating component iter domains. Previously, I was thinking about using the exact graph to find the corresponding component iter domain for a given ragged iter domain (e.g., #5550 (comment)). However, it won't work, for example, when a fusion is segmented and a segment does not have the corresponding
Partitionexpr for aRaggedIterDomain. For example, when a tensor is used as an input forasNested, followed by some other operations, if the fusion is segmented after some operations, the latter segment won't be able to see theasNestedand thePartitionoperations as they don't exist in the segment. This could be alleviated by providing an exact graph for the whole complete fusion, but more fundamentally, if a fusion has a nested tensor as an input, there doesn't seem to be any reasonable way to attach aPartitionexpr.See doc/dev/ragged_iter_domain_combine_design_doc.md for detailed discussions. At this moment, I decided to not worry too much about the validation and assume the correctness is guaranteed by the user.
Note that partitioning is still limited to 1D extents. Multi-dim offsets will be the next step of this series of RPs.